[AMD] Enable preshuffle paged MQA and page_size=64 for NSA indexer#23562
Open
1am9trash wants to merge 12 commits into sgl-project:main from
Conversation
Contributor
Code Review
This pull request introduces a 'preshuffle' layout for DeepSeek DSA on HIP, integrating AITER-based kernels for gathering and storing K/S data and updating the default page size to 64. A significant issue was identified where falling back to Triton kernels on HIP (when AITER is disabled) would result in a layout mismatch, as these kernels do not yet support the preshuffle layout required by the MQA kernel.
This was referenced Apr 23, 2026
Jacob0226 added a commit to Jacob0226/sglang that referenced this pull request on Apr 29, 2026
Squashed implementation of three HIP-only optimizations that together shrink the GLM-5-FP8 NSA tilelang decode layer on MI355X from ~397 us to ~324 us (-73 us / -18.4%, MI355X TP=8 fp8 KV cache).

==============================================================================
1. fix(rocm): restore `_is_hip` in DeepseekV2Model.alt_stream creation
==============================================================================
Commit a1ceb2e ("[AMD] Enable MoE dual stream overlap on HIP for GLM4/GLM5") added `_is_hip` to the alt_stream gate. The MUSA backend PR b35213b ("[MUSA][16/N] Add MUSA backend support for layers and DeepSeek models") was branched off a parent that did not contain a1ceb2e, and on merge inadvertently dropped `_is_hip` while adding `_is_musa`. Result: on ROCm `self.alt_stream is None`, so `forward_normal_dual_stream` and the MLA dual-stream fork are never entered; decode traces show only one physical stream.

This commit restores `_is_hip` alongside `_is_musa` and re-applies the `not _use_aiter` guard in `forward_normal_dual_stream`'s routed_scaling_factor multiply (aiter's biased_grouped_topk already fuses the scaling, so multiplying again would double it). Both changes are HIP-only: CUDA / MUSA / NPU branches are unaffected.

==============================================================================
2. perf(rocm-nsa): A_v4 dual-stream layout in forward_absorb_prepare
==============================================================================
Refactor the q_b_proj / NSA-indexer dual-stream fork in DeepseekMLAForwardMixin.forward_absorb_prepare so that on HIP the indexer chain on alt overlaps not just with q_b_proj but also with the gap-fill that follows on cur (bmm w_kc absorb + rotary_emb on q_pe/k_pe, plus fused_qk_rope_cat_and_cache_mla on the gfx95 NSA tilelang path).

Two HIP-graph capture rules drive the layout (validated by the microbenchmark in SGLang-benchmarks/tools/glm5_proposalA_v3_test.py variant A_v4: -18.9 us/layer over the prior layout):
1. Dispatch order picks the physical stream: the branch dispatched first at the fork keeps the predecessor stream (phys 0); the later-dispatched branch lands on a fresh aux stream (phys 4). We dispatch q_b_proj on cur FIRST and only afterwards enter `with stream(alt):` for the indexer.
2. `alt.wait_stream(cur)` snapshots cur's state at call time. Since the indexer needs only q_lora (phase1 output), placing wait_stream BEFORE q_b_proj lets alt's heavy indexer chain start the instant phase1 completes, in parallel with cur's q_b_proj plus gap-fill, instead of waiting for q_b_proj first.

The `cur.wait_stream(alt)` join is moved past rotary_emb so cur's gap-fill chain overlaps with alt's indexer. CUDA / MUSA / NPU paths are gated to keep the original PR sgl-project#23562 layout (byte-identical); these were not validated under the new schedule. Drives the `overlap_indexer_with_gap_fill` flag used by sub-optimization (3).

==============================================================================
3. perf(rocm-nsa): pull fused_qk_rope_cat_and_cache_mla into the dual-stream window, and skip the redundant CatArrayBatchedCopy that follows attn_mqa
==============================================================================
For the gfx95 NSA tilelang fused-rope path, the `fused_qk_rope_cat_and_cache_mla` kernel that normally runs in `forward_absorb_core` is moved into `forward_absorb_prepare` so it runs on cur inside the dual-stream window, overlapping with the alt indexer instead of running serially after the join. The result is forwarded from prepare to core via a new optional `fused_qk_kv_cache` return field; core falls back to the original inline computation when the prepare-side fast path was not taken (non-capture, non-decode, or non-HIP).

In addition, `forward_absorb_core` now passes the already-concatenated `q_cat` directly to `attn_mqa` with `q_rope=None` on the decode path (prefill keeps the split form because `nsa_backend.forward_extend` asserts `q_rope is not None`). On the receiving side, `nsa_backend.forward_decode` is updated to track `q_all` explicitly:
- When the caller passes split q_nope / q_rope (CUDA / non-HIP paths or non-decode HIP), q_all is initialized to None and each impl block re-cats as before, byte-identical to the pre-patch behavior.
- When the caller passes q_rope=None on HIP decode, q_all is set to a zero-copy `q.contiguous().view(...)` of `q_cat` and each impl block skips the otherwise-redundant `concat_mla_absorb_q_general` call.

The cat-skip is gated `if q_all is None or not _is_hip` so non-HIP backends always re-cat (preserves prior behavior bit-exactly). This eliminates the CatArrayBatchedCopy<OpaqueType<1u>, ...> kernel that previously fired once per layer per decode step (~5 us/layer) between fused_qk_rope_cat and main_kernel on ROCm tilelang traces: 390 invocations → 0 in the DualStream0429_v2 trace.

==============================================================================
Validation
==============================================================================
* MI355X TP=8 GLM-5.1-FP8 fp8 KV cache, NSA tilelang decode:
  - Layer latency: ~397 us → ~324 us (-73 us / -18.4%)
  - 8k1k conc4 TPOT: 24.48 ms median (output throughput 117 tok/s)
  - GSM8K 1200q: 0.953 (PR sgl-project#23562 baseline 0.951)
* trace: results/.../GLM-5.1-FP8-prof-DualStream0429_v2/prof_in8192_out1024_conc4_p8/*-TP-0-DECODE.trace.json.gz
* Stacks on top of sgl-project#23562 (preshuffled paged MQA + page_size=64) and requires aiter PR ROCm/aiter#2879 (preshuffle layout in indexer k-cache kernels).

==============================================================================
Files
==============================================================================
* deepseek_v2.py (+5 -2) alt_stream gate + routed_scaling guard
* forward_mla.py (+212 -73) A_v4 layout + fused pull-up + cat-skip plumbing, HIP-only via `_is_hip` gate
* nsa_backend.py (+15 -4) q_all tracking + cat-skip, HIP-only
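The dispatch-order and wait_stream placement rules above can be modeled in a few lines of plain Python. This is a scheduling sketch only, not SGLang or PyTorch code: streams are timelines of named ops, and `wait_stream(other)` snapshots the ops already enqueued on `other` at call time, mirroring the semantics described in the commit message. All op names are illustrative.

```python
# Minimal pure-Python model of the wait_stream placement rule (illustrative).
class Stream:
    def __init__(self, name):
        self.name = name
        self.ops = []      # ops enqueued on this stream, in order
        self.deps = set()  # ops from other streams this stream must wait for

    def launch(self, op):
        self.ops.append(op)

    def wait_stream(self, other):
        # Snapshot semantics: depend only on what `other` has enqueued SO FAR.
        self.deps |= set(other.ops)

def indexer_deps(wait_before_q_b_proj: bool):
    cur, alt = Stream("cur"), Stream("alt")
    cur.launch("phase1_q_lora")        # q_lora is ready after phase1
    if wait_before_q_b_proj:
        alt.wait_stream(cur)           # A_v4: snapshot BEFORE q_b_proj
    cur.launch("q_b_proj")
    if not wait_before_q_b_proj:
        alt.wait_stream(cur)           # naive: snapshot after q_b_proj
    alt.launch("nsa_indexer")
    return alt.deps

# A_v4 placement: the indexer waits only for phase1, so it overlaps q_b_proj.
assert indexer_deps(True) == {"phase1_q_lora"}
# Naive placement: the indexer must also wait for q_b_proj.
assert indexer_deps(False) == {"phase1_q_lora", "q_b_proj"}
```

The model captures why moving `alt.wait_stream(cur)` above q_b_proj widens the overlap window: the snapshot taken earlier carries fewer dependencies into alt.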
Jacob0226 added a commit to Jacob0226/sglang that referenced this pull request on Apr 29, 2026
…ual-stream

This commit lands two HIP-only optimizations on top of PR sgl-project#23562:
1. Cat-skip in nsa_backend.forward_decode (default ON, ~2.6 us / layer)
2. A_v4 NSA dual-stream layout (gated OFF by default; regresses on MI355X)

Validated on MI355X TP=8 GLM-5.1-FP8 (8k1k conc4):

Variant                                                | Median TPOT | Δ vs Thomas
------------------------------------------------------ | ----------- | ------------------
Thomas (PR sgl-project#23562 only)                     | 21.21 ms    | baseline
This commit, default (cat-skip on, dual-stream off)    | 20.48 ms    | −3.4% (faster)
This commit + SGLANG_ENABLE_HIP_DUAL_STREAM=1 + --disable-shared-experts-fusion | 24.45 ms | +15.3% (regression)

==============================================================================
1. Cat-skip optimization (default ON, HIP-only)
==============================================================================
In the NSA TileLang fused-rope decode path, fused_qk_rope_cat_and_cache_mla produces a contiguous `q_cat` tensor of shape (M, num_heads, kv_lora_rank + qk_rope_head_dim). The pre-patch flow then sliced q_cat into (q_nope_fused, q_pe_fused) and passed them as separate args to attn_mqa, which causes nsa_backend.forward_decode to call concat_mla_absorb_q_general(q_nope, q_rope) to rebuild q_all. On ROCm that fallback hits torch.cat → CatArrayBatchedCopy, producing a tensor that is byte-identical to the q_cat we already have.

forward_absorb_core now passes q_cat directly to attn_mqa with q_rope=None on the decode path (prefill keeps the split form because forward_extend asserts q_rope is not None). nsa_backend.forward_decode is updated to track q_all explicitly:
- When the caller passes split q_nope / q_rope, q_all=None and each impl block re-cats as before, byte-identical to pre-patch behavior.
- When the caller passes q_rope=None on HIP decode, q_all is set to a zero-copy `q.contiguous().view(...)` and the cat is skipped.
The cat-skip is gated `if q_all is None or not _is_hip` so non-HIP backends always re-cat (preserves CUDA / MUSA behavior bit-exactly).

Effect: the CatArrayBatchedCopy<OpaqueType<1u>, ...> kernel that previously fired once per layer per decode step disappears from ROCm tilelang traces.

==============================================================================
2. A_v4 dual-stream layout (opt-in via SGLANG_ENABLE_HIP_DUAL_STREAM=1)
==============================================================================
forward_absorb_prepare gains a HIP-only A_v4 dual-stream layout that overlaps the NSA indexer chain on alt with [q_b_proj + bmm w_kc + fused_qk_rope_cat] on cur. Two HIP-graph capture rules drive the layout:
1. Dispatch order picks the physical stream: the branch dispatched first keeps the predecessor stream (phys 0); the later-dispatched branch lands on a fresh aux stream (phys 4). q_b_proj is dispatched on cur FIRST, then `with stream(alt):` for the indexer.
2. alt.wait_stream(cur) is placed BEFORE q_b_proj. The indexer needs only q_lora (phase1 output), not q_b_proj's q, so alt's heavy indexer chain can start the moment phase1 completes, in parallel with cur's q_b_proj plus gap-fill.

The cur.wait_stream(alt) join is moved past rotary_emb so cur's gap-fill chain overlaps with alt's indexer. fused_qk_rope_cat_and_cache_mla is also pulled from forward_absorb_core into prepare's dual-stream window, with the result forwarded via a new optional fused_qk_kv_cache return field. CUDA / MUSA / NPU paths take the original q_b_proj ∥ NSA-indexer layout from the PR sgl-project#23562 base (byte-identical); the new layout was not validated on those platforms.

Why opt-in: on MI355X the layout regresses ~30 us / layer due to three contention sources:
- HBM bandwidth contention: the indexer's memory-bound kernels lose 0.5-2.4 us each when sharing HBM with cur GEMMs (+8 us total).
- Compute-unit split: the scheduler partitions 256 CUs across concurrent kernels, slowing both compute-bound kernels (+5 us total).
- HIP-graph AllReduce slowdown: aiter::cross_device_reduce_1stage takes 23 us under dual-stream graph capture vs 9.5 us single-stream, with the same kernel and the same TP=8 topology. Likely caused by the AR's first-stage peer fence having to drain alt's KV-cache writes too. ~+26 us / layer (2 ARs).

The theoretical A_v4 saving (gap-fill ∥ indexer ≈ −10 us / layer) is dwarfed by these costs. The layout is preserved behind SGLANG_ENABLE_HIP_DUAL_STREAM for future ROCm releases that may fix the AR fence cost. To enable for testing:

SGLANG_ENABLE_HIP_DUAL_STREAM=1 ./GLM.sh --dual-stream-rocm ...

==============================================================================
Files changed
==============================================================================
environ.py (+8) New env var SGLANG_ENABLE_HIP_DUAL_STREAM
deepseek_v2.py (+15 -2) alt_stream gate now requires _is_hip + env var. forward_normal_dual_stream's routed_scaling multiply also adds `not _use_aiter` (aiter's biased_grouped_topk already fuses the scaling).
forward_mla.py (+212 -73) A_v4 layout in forward_absorb_prepare (gated on _is_hip; degrades to serial when alt_stream is None). fused_qk_rope_cat pull-up + q_rope=None cat-skip plumbing in forward_absorb_core.
nsa_backend.py (+15 -4) q_all tracking + cat-skip in forward_decode. HIP-only; non-HIP always re-cats.

Stacks on top of PR sgl-project#23562 (preshuffled paged MQA + page_size=64) and requires aiter PR ROCm/aiter#2879 (preshuffle layout in indexer k-cache kernels). Detailed regression analysis: ~/SGLang-benchmarks/tmp/dual_stream_regression_analysis.md
Jacob0226 added a commit to Jacob0226/sglang that referenced this pull request on May 6, 2026
GLM-5 NSA TileLang decode on ROCm dispatches a `CatArrayBatchedCopy` kernel
once per layer per decode step that rebuilds an already-existing tensor.
This is a strict-improvement bug fix: ~2.6 us / layer saved, 0 changes for
non-HIP backends.
==============================================================================
Root cause
==============================================================================
For the NSA TileLang fused-rope decode path (`_use_aiter_gfx95 + nsa +
nsa_decode_backend == "tilelang"`), `forward_absorb_core` calls
`fused_qk_rope_cat_and_cache_mla` which produces a contiguous q_cat tensor
of shape (M, num_heads, kv_lora_rank + qk_rope_head_dim). The pre-patch
flow then sliced q_cat into q_nope_fused / q_pe_fused and passed them as
separate args to attn_mqa.
attn_mqa -> NSABackend.forward_decode then takes the if-branch (q_rope
is not None), views the slices, and for tilelang / flashmla_sparse /
flashmla_kv / aiter decode impls calls
`concat_mla_absorb_q_general(q_nope, q_rope)` to rebuild q_all. On ROCm,
that helper falls back to `torch.cat([q_nope, q_rope], dim=-1)`, which
allocates a fresh contiguous tensor and dispatches a copy kernel. The
result is byte-identical to the q_cat we already had — the cat is pure
overhead.
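The "byte-identical rebuild" claim above can be illustrated in plain Python: slicing a contiguous buffer and concatenating the slices back reproduces the original bytes exactly, which is the CPU-side analogue of re-cat'ing q_nope / q_rope slices of an already-contiguous q_cat. Shapes and the 4-byte element size are taken from the commit text; the buffer contents are arbitrary.

```python
# Illustration: concatenating slices of a contiguous buffer rebuilds the
# original bytes exactly -- the copy is pure overhead (like CatArrayBatchedCopy).
kv_lora_rank, qk_rope_head_dim = 512, 64           # dims from the commit text
row_bytes = (kv_lora_rank + qk_rope_head_dim) * 4  # fp32-sized entries, per head
q_cat = bytes(range(256)) * (row_bytes // 256)     # arbitrary contiguous data

# Pre-patch flow: slice q_cat, then rebuild with a copy (torch.cat analogue).
q_nope = q_cat[: kv_lora_rank * 4]
q_rope = q_cat[kv_lora_rank * 4 :]
rebuilt = q_nope + q_rope                          # allocates and copies

assert rebuilt == q_cat                            # byte-identical to the input
```

On the GPU the fix is the mirror image: keep q_cat and take a zero-copy view, instead of slicing and paying for the concatenation kernel.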
==============================================================================
Fix
==============================================================================
(1) `forward_absorb_core` now passes q_cat directly to attn_mqa with
q_rope=None on the decode path. Prefill (forward_extend) keeps the
split form because `nsa_backend.forward_extend` asserts
`q_rope is not None`.
(2) `nsa_backend.forward_decode` is updated to track q_all explicitly:
- When the caller passes split q_nope / q_rope, q_all is initialized
to None and each impl block re-cats as before (byte-identical to
pre-patch behavior).
- When the caller passes q_rope=None on HIP, q_all is set to a
zero-copy `q.contiguous().view(...)` and the cat is skipped.
The cat-skip is gated `if q_all is None or not _is_hip` so non-HIP
backends always re-cat (preserves CUDA / MUSA paths bit-exactly).
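The gating described in (2) can be sketched as a small helper. This is a simplified stand-in, not the actual nsa_backend code: `resolve_q_all` and `fake_cat` are hypothetical names, tuples stand in for tensors, and `concat` models concat_mla_absorb_q_general. Only the `if q_all is None or not _is_hip` gate is taken from the commit text.

```python
# Sketch of the q_all gating (illustrative names; tuples stand in for tensors).
def resolve_q_all(q, q_rope, *, is_hip, concat):
    """Return the concatenated query each impl block should use."""
    # q_rope=None on HIP decode means the caller already passed q_cat as `q`:
    # reuse it zero-copy instead of re-cat'ing.
    q_all = q if (q_rope is None and is_hip) else None

    if q_all is None or not is_hip:
        # Split-arg callers (and all non-HIP backends) re-cat as before,
        # preserving pre-patch behavior.
        q_all = concat(q, q_rope)
    return q_all

cat_calls = []
def fake_cat(q_nope, q_rope):
    cat_calls.append(1)                      # count copy-kernel dispatches
    return (q_nope or ()) + (q_rope or ())

# HIP decode with pre-concatenated q: the cat is skipped entirely.
assert resolve_q_all((1, 2, 3), None, is_hip=True, concat=fake_cat) == (1, 2, 3)
assert cat_calls == []
# Split caller: re-cats exactly as before the patch.
assert resolve_q_all((1, 2), (3,), is_hip=True, concat=fake_cat) == (1, 2, 3)
assert cat_calls == [1]
```

The counter makes the win visible: on the fast path the copy never fires, matching the disappearance of CatArrayBatchedCopy from the traces.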
==============================================================================
Validation
==============================================================================
MI355X TP=8 GLM-5.1-FP8 fp8 KV cache, NSA TileLang decode (on top of
PR sgl-project#23562 + aiter PR ROCm/aiter#2879):
scenario | before | after | TPOT Δ
--------------------- | --------- | --------- | --------
8k1k conc4 TPOT | 21.21 ms | 20.76 ms | -2.17%
8k1k conc8 TPOT | 25.28 ms | 24.82 ms | -1.82%
8k1k conc16 TPOT | 30.79 ms | 30.33 ms | -1.49%
8k1k conc32 TPOT | 42.92 ms | 42.46 ms | -1.07%
8k1k conc64 TPOT | 61.79 ms | 61.33 ms | -0.74%
1k1k conc4 TPOT | 18.79 ms | 18.33 ms | -2.45%
1k1k conc8 TPOT | 21.14 ms | 20.66 ms | -2.27%
1k1k conc16 TPOT | 23.63 ms | 23.15 ms | -2.03%
1k1k conc32 TPOT | 29.19 ms | 28.69 ms | -1.71%
1k1k conc64 TPOT | 35.02 ms | 34.60 ms | -1.20%
Output throughput improves by the same percentage on every scenario.
Cat-skip's absolute ~2.6 us / layer benefit is constant; the relative
gain is highest at small batch + short prompt (where total layer time is
smallest) and decays with batch size.
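The decay argument above is simple arithmetic: a constant absolute per-layer saving divided by a growing per-layer time gives a shrinking relative gain. The layer times below are illustrative placeholders, not measurements; only the 2.6 us figure comes from the text.

```python
# Constant absolute saving vs. growing layer time => decaying relative gain.
saving_us = 2.6                                   # cat-skip saving per layer
layer_times_us = [100.0, 200.0, 400.0, 800.0]     # illustrative, grows with batch

gains = [saving_us / t for t in layer_times_us]   # relative gain per layer

# The relative gain strictly decays as layer time grows.
assert all(a > b for a, b in zip(gains, gains[1:]))
assert abs(gains[0] - 0.026) < 1e-12              # 2.6 us on a 100 us layer = 2.6%
```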
GSM8K accuracy: 0.942 vs 0.951 baseline (within run-to-run variance
observed across multiple runs of the same config: 0.946-0.953).
==============================================================================
Files
==============================================================================
forward_mla.py (+50 -16) forward_absorb_core:_skip_rope_for_nsa_tilelang_fused
branch passes q_cat with q_rope=None for decode.
nsa_backend.py (+12 -4) forward_decode tracks q_all and skips cat on HIP
when caller already provided concatenated q.
Motivation
In the NSA indexer, paged MQA computation is very slow at high concurrency. For example, at 8k1k cc64 it takes ~88 us per layer, consisting of two kernels (logits tensor init, 11 us, plus the MQA kernel, 77 us).
The main bottleneck is the oversized block_table. The block_table shape is (batch_size, max_seq_len / page_size). With page_size=1 and max_seq_len=131k at cc64, the block_table shape is (64, 131072), totaling 64 × 131072 × 4 B ≈ 32 MB, and the frequent indirect loads cause poor MQA kernel performance.

Modifications
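The motivating arithmetic, and the effect of moving to page_size=64, can be checked in a few lines (4-byte int32 entries assumed, as in the size estimate above):

```python
# block_table footprint: (batch_size, max_seq_len / page_size) int32 entries.
def block_table_bytes(batch_size, max_seq_len, page_size, entry_bytes=4):
    return batch_size * (max_seq_len // page_size) * entry_bytes

before = block_table_bytes(64, 131072, page_size=1)
after = block_table_bytes(64, 131072, page_size=64)

assert before == 32 * 1024 * 1024   # 64 x 131072 x 4 B = 32 MiB
assert after == 512 * 1024          # 64x smaller: 512 KiB
assert before // after == 64
```

Shrinking the table 64x is what removes the indirect-load pressure on the MQA kernel.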
To reduce the block_table size, we cannot keep page_size=1; in this PR we change it to 64. This introduces the following changes:
- With page_size > 1, we switch to the MQA preshuffle kernel, which supports page_size as a multiple of 16.
- The indexer's K-cache store kernels use the preshuffle layout (preshuffle=True).

Other changes:
- Replace torch.full(..., -inf) with torch.empty(...) to eliminate an unnecessary initialization kernel.
- Update the page_size initialization and corresponding assertions.

Accuracy Tests
GLM-5.1-FP8 launch cmd
MI355 GSM8k (TP8): 0.951
Speed Tests and Profiling
Per-layer profiling:
Critical case (8k1k cc64):
Benchmark on MI355X TP8, concurrency 4/8/16/32/64 averaged: